deepSSF next-step validation

Author

Scott Forrest

Published

February 11, 2025

Abstract

In this script, we will assess how well the deepSSF model predicts the next step in the observed trajectory.

Whilst the realism and emergent properties of simulated trajectories are difficult to assess, we can validate the deepSSF models on their predictive performance at the next step, for each of the habitat selection, movement and next-step probability surfaces. Ensuring that the probability surfaces are normalised to sum to one, they can be compared to the predictions of typical step selection functions when the same probability surfaces are generated for the same local covariates. This approach not only allows for comparison between models, but can be informative as to when in the observed trajectory the model performs well or poorly, which can be analysed across the entire tracking period or for each hour of the day, and can lead to critical evaluation of the covariates that are used by the models and allow for model refinement.

Import packages

Code
# If using Google Colab, uncomment the following line
# !pip install rasterio

import sys
print(sys.version)  # Print Python version in use

import numpy as np                                      # Array operations
import matplotlib.pyplot as plt                         # Plotting library
import torch                                            # Main PyTorch library
import torch.optim as optim                             # Optimization algorithms
import torch.nn as nn                                    # Neural network modules
import os                                               # Operating system utilities
import pandas as pd                                     # Data manipulation
import rasterio                                         # Geospatial raster data

from torch.utils.data import Dataset, DataLoader        # Dataset and batch data loading
from datetime import datetime                           # Date/time utilities
from rasterio.plot import show                           # Plot raster data

# Get today's date as a YYYY-MM-DD string (this assignment was previously
# duplicated verbatim; a single copy is sufficient)
today_date = datetime.today().strftime('%Y-%m-%d')
3.12.5 | packaged by Anaconda, Inc. | (main, Sep 12 2024, 18:18:29) [MSC v.1929 64 bit (AMD64)]

If using Google Colab, uncomment the following lines

The file directories will also need to be changed to match the location of the files in your Google Drive.

Code
# from google.colab import drive
# drive.mount('/content/drive')
Code
image_dim = 101        # width/height of the local window (pixels)
pixel_size = 25        # size of one pixel (same units as the distances below)
center = image_dim // 2
y, x = np.indices((image_dim, image_dim))

# Distance from the central cell to every cell in the window
distance_layer = np.sqrt((pixel_size*(x - center))**2 + (pixel_size*(y - center))**2)
# change the centre cell to the average distance from the centre to the edge of the pixel
distance_layer[center, center] = 0.56*pixel_size # average distance from the centre to the perimeter of the pixel (accounting for longer distances at the corners)
# Angle from the central cell to every cell (row axis flipped so angles
# increase towards the top of the plotted image)
bearing_layer = np.arctan2(center - y, x - center)

# Plot both layers. Keep the AxesImage handle returned by imshow and pass it
# to colorbar — the original called imshow twice per panel (once to display,
# once just to feed colorbar), rendering each image twice.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
dist_im = ax[0].imshow(distance_layer)
ax[0].set_title('Distance Layer')
fig.colorbar(dist_im, ax=ax[0])
bear_im = ax[1].imshow(bearing_layer)
ax[1].set_title('Bearing Layer')
fig.colorbar(bear_im, ax=ax[1])
plt.show()

Import data

Code
# Get today's date
today_date = datetime.today().strftime('%Y-%m-%d')

# select the id to train the model on
buffalo_id = 2005
n_samples = 10297  # number of rows expected in this individual's validation file (encoded in the filename)

# buffalo_id = 2014
# n_samples = 6572

# buffalo_id = 2327
# n_samples = 8983

# buffalo_id = 2387
# n_samples = 10409

# Specify the path to your CSV file
csv_file_path = f'../buffalo_local_data_id/validation/validation_buffalo_{buffalo_id}_data_df_lag_1hr_n{n_samples}.csv'

# Read the CSV file into a DataFrame
buffalo_df = pd.read_csv(csv_file_path)

# Display the first few rows of the DataFrame
print(buffalo_df.head())

# Lag the 'bearing' column by one index so each row also carries the
# previous step's bearing (used later as the reference for turning angles)
buffalo_df['bearing_tm1'] = buffalo_df['bearing'].shift(1)
# Pad the missing value (first row after the shift) with 0
buffalo_df['bearing_tm1'] = buffalo_df['bearing_tm1'].fillna(0)

# Display the first few rows of the DataFrame
print(buffalo_df.head())
             x_            y_                    t_    id           x1_  \
0  41969.310875 -1.435671e+06  2018-07-25T01:04:23Z  2005  41969.310875   
1  41921.521939 -1.435654e+06  2018-07-25T02:04:39Z  2005  41921.521939   
2  41779.439594 -1.435601e+06  2018-07-25T03:04:17Z  2005  41779.439594   
3  41841.203272 -1.435635e+06  2018-07-25T04:04:39Z  2005  41841.203272   
4  41655.463332 -1.435604e+06  2018-07-25T05:04:27Z  2005  41655.463332   

            y1_           x2_           y2_     x2_cent    y2_cent  ...  \
0 -1.435671e+06  41921.521939 -1.435654e+06  -47.788936  16.857110  ...   
1 -1.435654e+06  41779.439594 -1.435601e+06 -142.082345  53.568427  ...   
2 -1.435601e+06  41841.203272 -1.435635e+06   61.763677 -34.322938  ...   
3 -1.435635e+06  41655.463332 -1.435604e+06 -185.739939  31.003534  ...   
4 -1.435604e+06  41618.651923 -1.435608e+06  -36.811409  -4.438037  ...   

         ta    cos_ta         x_min         x_max         y_min         y_max  \
0  1.367942  0.201466  40706.810875  43231.810875 -1.436934e+06 -1.434409e+06   
1 -0.021429  0.999770  40659.021939  43184.021939 -1.436917e+06 -1.434392e+06   
2  2.994917 -0.989262  40516.939594  43041.939594 -1.436863e+06 -1.434338e+06   
3 -2.799767 -0.942144  40578.703272  43103.703272 -1.436898e+06 -1.434373e+06   
4  0.285377  0.959556  40392.963332  42917.963332 -1.436867e+06 -1.434342e+06   

   s2_index  points_vect_cent  year_t2  yday_t2_2018_base  
0         7               NaN     2018                206  
1         7               NaN     2018                206  
2         7               NaN     2018                206  
3         7               NaN     2018                206  
4         7               NaN     2018                206  

[5 rows x 35 columns]
             x_            y_                    t_    id           x1_  \
0  41969.310875 -1.435671e+06  2018-07-25T01:04:23Z  2005  41969.310875   
1  41921.521939 -1.435654e+06  2018-07-25T02:04:39Z  2005  41921.521939   
2  41779.439594 -1.435601e+06  2018-07-25T03:04:17Z  2005  41779.439594   
3  41841.203272 -1.435635e+06  2018-07-25T04:04:39Z  2005  41841.203272   
4  41655.463332 -1.435604e+06  2018-07-25T05:04:27Z  2005  41655.463332   

            y1_           x2_           y2_     x2_cent    y2_cent  ...  \
0 -1.435671e+06  41921.521939 -1.435654e+06  -47.788936  16.857110  ...   
1 -1.435654e+06  41779.439594 -1.435601e+06 -142.082345  53.568427  ...   
2 -1.435601e+06  41841.203272 -1.435635e+06   61.763677 -34.322938  ...   
3 -1.435635e+06  41655.463332 -1.435604e+06 -185.739939  31.003534  ...   
4 -1.435604e+06  41618.651923 -1.435608e+06  -36.811409  -4.438037  ...   

     cos_ta         x_min         x_max         y_min         y_max  s2_index  \
0  0.201466  40706.810875  43231.810875 -1.436934e+06 -1.434409e+06         7   
1  0.999770  40659.021939  43184.021939 -1.436917e+06 -1.434392e+06         7   
2 -0.989262  40516.939594  43041.939594 -1.436863e+06 -1.434338e+06         7   
3 -0.942144  40578.703272  43103.703272 -1.436898e+06 -1.434373e+06         7   
4  0.959556  40392.963332  42917.963332 -1.436867e+06 -1.434342e+06         7   

   points_vect_cent  year_t2  yday_t2_2018_base  bearing_tm1  
0               NaN     2018                206     0.000000  
1               NaN     2018                206     2.802478  
2               NaN     2018                206     2.781049  
3               NaN     2018                206    -0.507220  
4               NaN     2018                206     2.976198  

[5 rows x 36 columns]

Importing spatial data

Global layers

NDVI

Code
# for monthly NDVI
file_path = '../mapping/cropped rasters/ndvi_monthly.tif'
# read the raster file
with rasterio.open(file_path) as src:
    # Read every band into one (bands, rows, cols) array
    ndvi_global = src.read([i for i in range(1, src.count + 1)])
    # Get the metadata of the raster
    ndvi_meta = src.meta
    # Affine transform mapping pixel <-> geographic coordinates (reused below)
    raster_transform = src.transform

    # Print the metadata to check for time component
    print("Metadata:", ndvi_meta)

    # Check for specific time-related metadata
    if 'TIFFTAG_DATETIME' in src.tags():
        print("Time component found:", src.tags()['TIFFTAG_DATETIME'])
    else:
        print("No explicit time component found in metadata.")

# the rasters don't contain a time component
Metadata: {'driver': 'GTiff', 'dtype': 'float32', 'nodata': nan, 'width': 2400, 'height': 2280, 'count': 24, 'crs': CRS.from_wkt('LOCAL_CS["GDA94 / Geoscience Australia Lambert",UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]'), 'transform': Affine(25.0, 0.0, 0.0,
       0.0, -25.0, -1406000.0)}
No explicit time component found in metadata.
Code
print(ndvi_meta)
print(raster_transform)
print(ndvi_global.shape)  # (bands, rows, cols)

# Replace NaNs in the original array with -1, which represents water
ndvi_global = np.nan_to_num(ndvi_global, nan=-1.0)

# from the stack of local layers
# (presumably so the global layer is scaled identically to the training
# covariates — confirm these bounds against the local-layer preprocessing)
ndvi_max = 0.8220
ndvi_min = -0.2772

ndvi_global_tens = torch.from_numpy(ndvi_global)

# Normalizing the data (min-max scaling with the bounds above)
ndvi_global_norm = (ndvi_global_tens - ndvi_min) / (ndvi_max - ndvi_min)

# plt.imshow(ndvi_global_norm.numpy())
# plt.colorbar()  
# plt.show()

# Visual check of band index 1 of the normalised stack
plt.imshow(ndvi_global_norm[1,:,:].numpy())
plt.colorbar()  
plt.show()

# plt.imshow(ndvi_global_norm[8,:,:].numpy())
# plt.colorbar()  
# plt.show()
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': nan, 'width': 2400, 'height': 2280, 'count': 24, 'crs': CRS.from_wkt('LOCAL_CS["GDA94 / Geoscience Australia Lambert",UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]'), 'transform': Affine(25.0, 0.0, 0.0,
       0.0, -25.0, -1406000.0)}
| 25.00, 0.00, 0.00|
| 0.00,-25.00,-1406000.00|
| 0.00, 0.00, 1.00|
(24, 2280, 2400)

Canopy cover

Code
file_path = '../mapping/cropped rasters/canopy_cover.tif'
# Open the canopy-cover raster and keep the first band plus its metadata
with rasterio.open(file_path) as src:
    canopy_global = src.read(1)  # band 1 as a 2D array
    canopy_meta = src.meta       # raster metadata (dtype, nodata, transform, ...)
Code
print(canopy_meta)
print(canopy_global.shape)

# Normalisation bounds taken from the stack of local layers
canopy_max = 82.5000
canopy_min = 0.0

# Min-max scale the canopy layer as a torch tensor
canopy_global_tens = torch.from_numpy(canopy_global)
canopy_global_norm = (canopy_global_tens - canopy_min) / (canopy_max - canopy_min)

# Quick visual check of the normalised layer
plt.imshow(canopy_global_norm.numpy())
plt.colorbar()
plt.show()
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': -3.3999999521443642e+38, 'width': 2400, 'height': 2280, 'count': 1, 'crs': CRS.from_wkt('LOCAL_CS["GDA94 / Geoscience Australia Lambert",UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]'), 'transform': Affine(25.0, 0.0, 0.0,
       0.0, -25.0, -1406000.0)}
(2280, 2400)

Herbaceous vegetation

Code
file_path = '../mapping/cropped rasters/veg_herby.tif'
# Open the herbaceous-vegetation raster and keep band 1 plus its metadata
with rasterio.open(file_path) as src:
    herby_global = src.read(1)  # band 1 as a 2D array
    herby_meta = src.meta       # raster metadata
Code
print(herby_meta)
print(herby_global.shape)

# Normalisation bounds taken from the stack of local layers
# (the layer is already binary, so this scaling is a no-op)
herby_max = 1.0
herby_min = 0.0

# Min-max scale the herbaceous-vegetation layer as a torch tensor
herby_global_tens = torch.from_numpy(herby_global)
herby_global_norm = (herby_global_tens - herby_min) / (herby_max - herby_min)

# Quick visual check of the normalised layer
plt.imshow(herby_global_norm.numpy())
plt.colorbar()
plt.show()
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': -3.3999999521443642e+38, 'width': 2400, 'height': 2280, 'count': 1, 'crs': CRS.from_wkt('LOCAL_CS["GDA94 / Geoscience Australia Lambert",UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]'), 'transform': Affine(25.0, 0.0, 0.0,
       0.0, -25.0, -1406000.0)}
(2280, 2400)

Slope

Code
file_path = '../mapping/cropped rasters/slope.tif'
# Open the slope raster; keep band 1, the metadata, and the affine transform
with rasterio.open(file_path) as src:
    slope_global = src.read(1)
    slope_meta = src.meta
    slope_transform = src.transform  # same as the raster transform in the NDVI raster read

print(slope_global)
print(slope_transform)
[[       nan        nan        nan ...        nan        nan        nan]
 [       nan 0.3352837  0.39781624 ... 0.         0.                nan]
 [       nan 0.3983888  0.48142004 ... 0.         0.                nan]
 ...
 [       nan 2.215875   1.9798415  ... 1.5078747  1.268342          nan]
 [       nan 1.9740707  1.7354656  ... 1.697194   1.4880029         nan]
 [       nan        nan        nan ...        nan        nan        nan]]
| 25.00, 0.00, 0.00|
| 0.00,-25.00,-1406000.00|
| 0.00, 0.00, 1.00|
Code
print(slope_meta)
print(slope_global.shape)

# Replace NaNs in the original array with 0.0
# NOTE(review): the comment here previously said "-1, which represents water"
# (copied from the NDVI cell), but the fill value used is 0.0 — confirm which
# is intended for slope.
slope_global = np.nan_to_num(slope_global, nan=0.0)

# from the stack of local layers
slope_max = 12.2981
slope_min = 0.0006

slope_global_tens = torch.from_numpy(slope_global)

# Normalizing the data
slope_global_norm = (slope_global_tens - slope_min) / (slope_max - slope_min)

plt.imshow(slope_global_norm.numpy())
plt.colorbar()  
plt.show()
{'driver': 'GTiff', 'dtype': 'float32', 'nodata': nan, 'width': 2400, 'height': 2280, 'count': 1, 'crs': CRS.from_wkt('LOCAL_CS["GDA94 / Geoscience Australia Lambert",UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]'), 'transform': Affine(25.0, 0.0, 0.0,
       0.0, -25.0, -1406000.0)}
(2280, 2400)

Plotting the global layers

Code
# Figure/axis pair for the geo-referenced slope map
fig, ax = plt.subplots(figsize=(7.5, 7.5))

# Draw the raster in geographic coordinates using its affine transform
show(slope_global, transform=raster_transform, ax=ax, cmap='viridis')

# Title and axis labels
ax.set_title('Raster with Geographic Coordinates')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')

# Render the figure
plt.show()

Going between cell and raster coordinates

Code
# Example geographic coordinates
x, y = 5.9e4, -1.447e6
print(x, y)

# Inverting the affine transform maps geographic -> pixel coordinates;
# round to the nearest integer to get cell indices
px, py = ~raster_transform * (x, y)
px, py = int(round(px)), int(round(py))

# Show the resulting pixel coordinates
print(px, py)
59000.0 -1447000.0
2360 1640

Subset function

Code
def subset_raster_with_padding_torch(raster_tensor, x, y, window_size, transform, pad_value=-1.0):
    """Extract a square window centred on geographic point (x, y) from a raster.

    Cells of the window that fall outside the raster bounds are filled with
    ``pad_value`` (default -1.0, matching the value used for missing data
    elsewhere in this script).

    Parameters
    ----------
    raster_tensor : torch.Tensor
        2D tensor (rows, cols) of raster values.
    x, y : float
        Geographic coordinates of the window centre.
    window_size : int
        Width/height of the window in pixels (assumed odd so the centre
        pixel is well defined).
    transform :
        Invertible affine transform of the raster (e.g. rasterio's
        ``src.transform``); ``~transform * (x, y)`` must yield pixel coords.
    pad_value : float, optional
        Fill value for out-of-bounds cells. Defaults to -1.0 for backward
        compatibility with the original hard-coded value.

    Returns
    -------
    (subset, col_start, row_start)
        The (window_size, window_size) tensor and the (possibly negative)
        pixel coordinates of its top-left corner in the full raster.
    """
    # Convert geographic coordinates to pixel coordinates
    px, py = ~transform * (x, y)

    # Truncate towards zero rather than round: rounding would jump to the
    # neighbouring cell when the fractional part exceeds 0.5, which is NOT
    # what we want — the point should stay in the cell it falls in.
    px, py = int(px), int(py)

    # Define half the window size
    half_window = window_size // 2

    # Window boundaries in full-raster pixel coordinates (stop is exclusive)
    row_start = py - half_window
    row_stop = py + half_window + 1
    col_start = px - half_window
    col_stop = px + half_window + 1

    # Start from a window filled entirely with the padding value
    subset = torch.full((window_size, window_size), pad_value, dtype=raster_tensor.dtype)

    # Clip the window to the raster bounds
    valid_row_start = max(0, row_start)
    valid_row_stop = min(raster_tensor.shape[0], row_stop)
    valid_col_start = max(0, col_start)
    valid_col_stop = min(raster_tensor.shape[1], col_stop)

    # Window lies entirely outside the raster: return the all-padding window
    # (the original code would have attempted a mismatched slice assignment)
    if valid_row_stop <= valid_row_start or valid_col_stop <= valid_col_start:
        return subset, col_start, row_start

    # Destination region inside the subset corresponding to the valid raster region
    subset_row_start = valid_row_start - row_start
    subset_row_stop = subset_row_start + (valid_row_stop - valid_row_start)
    subset_col_start = valid_col_start - col_start
    subset_col_stop = subset_col_start + (valid_col_stop - valid_col_start)

    # Copy the in-bounds data into place; everything else keeps pad_value
    subset[subset_row_start:subset_row_stop, subset_col_start:subset_col_stop] = \
        raster_tensor[valid_row_start:valid_row_stop, valid_col_start:valid_col_stop]

    return subset, col_start, row_start

Testing the subset function

Code
# Test location (geographic coordinates) and window size for the subset function
x = 5.9e4
y = -1.41e6
window_size = 101

# Band index into the NDVI stack (presumably a month — confirm band ordering)
which_ndvi = 1

# Subset each global covariate layer around the test location. Every call
# uses the same point/window/transform, so the returned origin is identical
# each time and is simply overwritten.
ndvi_subset, origin_x, origin_y = subset_raster_with_padding_torch(ndvi_global_norm[which_ndvi,:,:], x, y, window_size, raster_transform)
canopy_subset, origin_x, origin_y = subset_raster_with_padding_torch(canopy_global_norm, x, y, window_size, raster_transform)
herby_subset, origin_x, origin_y = subset_raster_with_padding_torch(herby_global_norm, x, y, window_size, raster_transform)
slope_subset, origin_x, origin_y = subset_raster_with_padding_torch(slope_global_norm, x, y, window_size, raster_transform)

# Plot the subset
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

axs[0, 0].imshow(ndvi_subset.numpy(), cmap='viridis')
axs[0, 0].set_title('NDVI Subset')
# axs[0, 0].axis('off')

axs[0, 1].imshow(canopy_subset.numpy(), cmap='viridis')
axs[0, 1].set_title('Canopy Cover Subset')
# axs[0, 1].axis('off')

axs[1, 0].imshow(herby_subset.numpy(), cmap='viridis')
axs[1, 0].set_title('Herbaceous Vegetation Subset')
# axs[1, 0].axis('off')

axs[1, 1].imshow(slope_subset.numpy(), cmap='viridis')
axs[1, 1].set_title('Slope Subset')
# axs[1, 1].axis('off')
Text(0.5, 1.0, 'Slope Subset')

Running the model on the subset layers

Set the device for the model

Code
# Run on the GPU when available; otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")
Using cpu device

Define the model

Code
class Conv2d_block_toFC(nn.Module):
    """Two conv → ReLU → max-pool stages, flattened to a feature vector.

    Used as the spatial front-end of the movement branch: the flattened
    output feeds a fully connected block.
    """

    def __init__(self, params):
        super(Conv2d_block_toFC, self).__init__()
        # Cache hyperparameters from the params object
        self.batch_size = params.batch_size
        self.input_channels = params.input_channels
        self.output_channels = params.output_channels
        self.kernel_size = params.kernel_size
        self.stride = params.stride
        self.kernel_size_mp = params.kernel_size_mp
        self.stride_mp = params.stride_mp
        self.padding = params.padding
        self.image_dim = params.image_dim
        self.device = params.device

        # Build the two identical conv/pool stages, then flatten
        layers = []
        in_ch = self.input_channels
        for _ in range(2):
            layers.append(nn.Conv2d(in_channels=in_ch,
                                    out_channels=self.output_channels,
                                    kernel_size=self.kernel_size,
                                    stride=self.stride,
                                    padding=self.padding))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(kernel_size=self.kernel_size_mp,
                                       stride=self.stride_mp))
            in_ch = self.output_channels
        layers.append(nn.Flatten())
        self.conv2d = nn.Sequential(*layers)

    def forward(self, x):
        """Run the convolutional stack and return the flattened features."""
        return self.conv2d(x)


class Conv2d_block_spatial(nn.Module):
    """Three-conv stack producing a log-normalised spatial selection surface.

    The final 1-channel map is log-softmaxed over the spatial grid, so
    exponentiating the output yields a probability surface summing to one
    per sample.
    """

    def __init__(self, params):
        super(Conv2d_block_spatial, self).__init__()
        # Cache hyperparameters from the params object
        self.batch_size = params.batch_size
        self.input_channels = params.input_channels
        self.output_channels = params.output_channels
        self.kernel_size = params.kernel_size
        self.stride = params.stride
        self.padding = params.padding
        self.image_dim = params.image_dim
        self.device = params.device

        # All three convolutions share kernel/stride/padding settings
        conv_kwargs = dict(kernel_size=self.kernel_size,
                           stride=self.stride,
                           padding=self.padding)
        self.conv2d = nn.Sequential(
            nn.Conv2d(self.input_channels, self.output_channels, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(self.output_channels, self.output_channels, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(self.output_channels, 1, **conv_kwargs),
        )

    def forward(self, x):
        """Return per-cell log-probabilities over the local grid."""
        logits = self.conv2d(x).squeeze(dim=1)  # drop the singleton channel axis
        # normalise in log space before combining with the movement grid
        return logits - torch.logsumexp(logits, dim=(1, 2), keepdim=True)


class FCN_block_all_habitat(nn.Module):
    """Fully connected head mapping combined features to a flat habitat grid.

    The output has image_dim * image_dim entries per sample, reshaped to a
    spatial grid downstream.
    """

    def __init__(self, params):
        super(FCN_block_all_habitat, self).__init__()
        # Cache hyperparameters from the params object
        self.batch_size = params.batch_size
        self.dense_dim_in_all = params.dense_dim_in_all
        self.dense_dim_hidden = params.dense_dim_hidden
        self.dense_dim_out = params.dense_dim_out
        self.image_dim = params.image_dim
        self.device = params.device
        self.dropout = params.dropout

        def dense_unit(n_in, n_out):
            # Linear -> Dropout -> ReLU hidden unit
            return [nn.Linear(n_in, n_out), nn.Dropout(self.dropout), nn.ReLU()]

        self.ffn = nn.Sequential(
            *dense_unit(self.dense_dim_in_all, self.dense_dim_hidden),
            *dense_unit(self.dense_dim_hidden, self.dense_dim_hidden),
            nn.Linear(self.dense_dim_hidden, self.image_dim * self.image_dim),
        )

    def forward(self, x):
        """Return a flat (batch, image_dim * image_dim) habitat vector."""
        return self.ffn(x)


class FCN_block_all_movement(nn.Module):
    """Fully connected head mapping combined features to movement parameters.

    Outputs num_movement_params values per sample, which parameterise the
    movement-kernel distributions downstream.
    """

    def __init__(self, params):
        super(FCN_block_all_movement, self).__init__()
        # Cache hyperparameters from the params object
        self.batch_size = params.batch_size
        self.dense_dim_in_all = params.dense_dim_in_all
        self.dense_dim_hidden = params.dense_dim_hidden
        self.dense_dim_out = params.dense_dim_out
        self.image_dim = params.image_dim
        self.device = params.device
        self.num_movement_params = params.num_movement_params
        self.dropout = params.dropout

        def dense_unit(n_in, n_out):
            # Linear -> Dropout -> ReLU hidden unit
            return [nn.Linear(n_in, n_out), nn.Dropout(self.dropout), nn.ReLU()]

        self.ffn = nn.Sequential(
            *dense_unit(self.dense_dim_in_all, self.dense_dim_hidden),
            *dense_unit(self.dense_dim_hidden, self.dense_dim_hidden),
            nn.Linear(self.dense_dim_hidden, self.num_movement_params),
        )

    def forward(self, x):
        """Return one vector of movement parameters per sample."""
        return self.ffn(x)

class FCN_block_nonspatial(nn.Module):
    """Fully connected block for non-spatial covariates.

    Maps a vector of scalar inputs to dense_dim_out features.
    """

    def __init__(self, params):
        super(FCN_block_nonspatial, self).__init__()
        # Cache hyperparameters from the params object
        self.batch_size = params.batch_size
        self.dense_dim_in_nonspatial = params.dense_dim_in_nonspatial
        self.dense_dim_hidden = params.dense_dim_hidden
        self.dense_dim_out = params.dense_dim_out
        self.image_dim = params.image_dim
        self.device = params.device
        self.dropout = params.dropout

        def dense_unit(n_in, n_out):
            # Linear -> Dropout -> ReLU hidden unit
            return [nn.Linear(n_in, n_out), nn.Dropout(self.dropout), nn.ReLU()]

        self.ffn = nn.Sequential(
            *dense_unit(self.dense_dim_in_nonspatial, self.dense_dim_hidden),
            *dense_unit(self.dense_dim_hidden, self.dense_dim_hidden),
            nn.Linear(self.dense_dim_hidden, self.dense_dim_out),
        )

    def forward(self, x):
        """Return dense_dim_out features per sample."""
        return self.ffn(x)



##################################################
# Mixture of two Gamma and von Mises distributions with the von Mises mu parameters allowed to vary
##################################################

class Params_to_Grid_Block(nn.Module):
    """Turn predicted movement parameters into a normalised log-density grid.

    The 12 values per sample in ``x`` parameterise a mixture of two Gamma
    distributions over step length (evaluated on the precomputed distance
    layer) and a mixture of two von Mises distributions over bearing
    (evaluated on the precomputed bearing layer). The two mixtures are
    combined by addition in log space and normalised so the exponentiated
    grid sums to one per sample.
    """

    def __init__(self, params):
        super(Params_to_Grid_Block, self).__init__()
        self.batch_size = params.batch_size
        self.image_dim = params.image_dim
        self.pixel_size = params.pixel_size
        self.center = self.image_dim // 2
        y, x = np.indices((self.image_dim, self.image_dim))
        # distance from the central cell to every cell (scaled by pixel_size)
        self.distance_layer = torch.from_numpy(np.sqrt((self.pixel_size*(x - self.center))**2 + (self.pixel_size*(y - self.center))**2))
        # change the centre cell to the average distance from the centre to the edge of the pixel
        self.distance_layer[self.center, self.center] = 0.56*self.pixel_size # average distance from the centre to the perimeter of the pixel (accounting for longer distances at the corners)
        # self.bearing_layer = torch.from_numpy(np.arctan2(y - self.center, x - self.center))
        # angle from the central cell to every cell (row axis flipped)
        self.bearing_layer = torch.from_numpy(np.arctan2(self.center - y, x - self.center))
        self.device = params.device


    # Gamma densities (log scale) for the mixture distribution
    def gamma_density(self, x, shape, scale):
        """Log-density of a Gamma(shape, scale) distribution evaluated at x."""
        # Ensure all tensors are on the same device as x
        shape = shape.to(x.device)
        scale = scale.to(x.device)
        return -1*torch.lgamma(shape) -shape*torch.log(scale) + (shape - 1)*torch.log(x) - x/scale

    # von Mises densities (log scale) for the mixture distribution
    def vonmises_density(self, x, kappa, vm_mu):
        """Log-density of a von Mises(vm_mu, kappa) distribution evaluated at x."""
        # Ensure all tensors are on the same device as x
        kappa = kappa.to(x.device)
        vm_mu = vm_mu.to(x.device)
        return kappa*torch.cos(x - vm_mu) - 1*(np.log(2*torch.pi) + torch.log(torch.special.i0(kappa)))


    def forward(self, x, bearing):
        """Build the (batch, image_dim, image_dim) movement log-density grid.

        Parameters
        ----------
        x : tensor
            Predicted movement parameters, one row of 12 values per sample;
            most are exponentiated below to enforce positivity.
        bearing : tensor
            Previous step's bearing per sample (column 0 is used); the
            predicted turning angles are added to it to get von Mises means.
        """

        # Each scalar parameter is exponentiated, then tiled to a
        # (batch, image_dim, image_dim) grid via unsqueeze -> repeat -> permute.

        # parameters of the first mixture distribution
        gamma_shape1 = torch.exp(x[:, 0]).unsqueeze(0).unsqueeze(0)
        gamma_shape1 = gamma_shape1.repeat(self.image_dim, self.image_dim, 1)
        gamma_shape1 = gamma_shape1.permute(2, 0, 1)

        gamma_scale1 = torch.exp(x[:, 1]).unsqueeze(0).unsqueeze(0)
        gamma_scale1 = gamma_scale1.repeat(self.image_dim, self.image_dim, 1)
        gamma_scale1 = gamma_scale1.permute(2, 0, 1)

        gamma_weight1 = torch.exp(x[:, 2]).unsqueeze(0).unsqueeze(0)
        gamma_weight1 = gamma_weight1.repeat(self.image_dim, self.image_dim, 1)
        gamma_weight1 = gamma_weight1.permute(2, 0, 1)

        # parameters of the second mixture distribution
        gamma_shape2 = torch.exp(x[:, 3]).unsqueeze(0).unsqueeze(0)
        gamma_shape2 = gamma_shape2.repeat(self.image_dim, self.image_dim, 1)
        gamma_shape2 = gamma_shape2.permute(2, 0, 1)

        gamma_scale2 = torch.exp(x[:, 4]).unsqueeze(0).unsqueeze(0)
        gamma_scale2 = gamma_scale2 * 500 ### to transform the scale parameter to be near 1
        gamma_scale2 = gamma_scale2.repeat(self.image_dim, self.image_dim, 1)
        gamma_scale2 = gamma_scale2.permute(2, 0, 1)

        gamma_weight2 = torch.exp(x[:, 5]).unsqueeze(0).unsqueeze(0)
        gamma_weight2 = gamma_weight2.repeat(self.image_dim, self.image_dim, 1)
        gamma_weight2 = gamma_weight2.permute(2, 0, 1)

        # Apply softmax to the weights (so the two mixture weights sum to 1)
        gamma_weights = torch.stack([gamma_weight1, gamma_weight2], dim=0)
        gamma_weights = torch.nn.functional.softmax(gamma_weights, dim=0)
        gamma_weight1 = gamma_weights[0]
        gamma_weight2 = gamma_weights[1]

        # calculation of Gamma densities
        # NOTE(review): `.to(device)` uses the module-level global `device`,
        # not `self.device` — confirm this is intended.
        gamma_density_layer1 = self.gamma_density(self.distance_layer, gamma_shape1, gamma_scale1).to(device)
        gamma_density_layer2 = self.gamma_density(self.distance_layer, gamma_shape2, gamma_scale2).to(device)

        # combining both densities to create a mixture distribution using logsumexp
        # (the max is subtracted/re-added for numerical stability)
        logsumexp_gamma_corr = torch.max(gamma_density_layer1, gamma_density_layer2)
        gamma_density_layer = logsumexp_gamma_corr + torch.log(gamma_weight1 * torch.exp(gamma_density_layer1 - logsumexp_gamma_corr) + gamma_weight2 * torch.exp(gamma_density_layer2 - logsumexp_gamma_corr))
        # print(torch.sum(gamma_density_layer))
        # print(torch.sum(torch.exp(gamma_density_layer)))

        # normalise the gamma weights so they sum to 1
        # gamma_density_layer = gamma_density_layer - torch.logsumexp(gamma_density_layer, dim = (1, 2), keepdim = True)

        # print(torch.sum(gamma_density_layer))
        # print(torch.sum(torch.exp(gamma_density_layer)))


        ## Von Mises Distributions

        # calculate the new bearing from the turning angle
        # takes in the bearing from the previous step and adds the turning angle
        bearing_new1 = x[:, 6] + bearing[:, 0]
        # bearing_new1 = 0.0 + bearing[:, 0]

        # print('Bearing.shape ', bearing.shape)
        # print('Bearing[:, 0].shape ', bearing[:, 0].shape)
        # print('Bearing[:, 0] ', bearing[:, 0])
        # the new bearing becomes the mean of the von Mises distribution
        # the estimated parameter x[:, 6] is the turning angle of the next step
        # which is always in reference to the input bearing
        vonmises_mu1 = bearing_new1.unsqueeze(0).unsqueeze(0)
        vonmises_mu1 = vonmises_mu1.repeat(self.image_dim, self.image_dim, 1)
        vonmises_mu1 = vonmises_mu1.permute(2, 0, 1)

        # parameters of the first von Mises distribution
        vonmises_kappa1 = torch.exp(x[:, 7]).unsqueeze(0).unsqueeze(0)
        vonmises_kappa1 = vonmises_kappa1.repeat(self.image_dim, self.image_dim, 1)
        vonmises_kappa1 = vonmises_kappa1.permute(2, 0, 1)

        vonmises_weight1 = torch.exp(x[:, 8]).unsqueeze(0).unsqueeze(0)
        vonmises_weight1 = vonmises_weight1.repeat(self.image_dim, self.image_dim, 1)
        vonmises_weight1 = vonmises_weight1.permute(2, 0, 1)

        # vm_mu and weight for the second von Mises distribution
        bearing_new2 = x[:, 9] + bearing[:, 0]
        # bearing_new2 = torch.pi + bearing[:, 0]
        # bearing_new2 = -torch.pi + bearing[:, 0]

        vonmises_mu2 = bearing_new2.unsqueeze(0).unsqueeze(0)
        vonmises_mu2 = vonmises_mu2.repeat(self.image_dim, self.image_dim, 1)
        vonmises_mu2 = vonmises_mu2.permute(2, 0, 1)

        # parameters of the second von Mises distribution
        vonmises_kappa2 = torch.exp(x[:, 10]).unsqueeze(0).unsqueeze(0)
        vonmises_kappa2 = vonmises_kappa2.repeat(self.image_dim, self.image_dim, 1)
        vonmises_kappa2 = vonmises_kappa2.permute(2, 0, 1)

        vonmises_weight2 = torch.exp(x[:, 11]).unsqueeze(0).unsqueeze(0)
        vonmises_weight2 = vonmises_weight2.repeat(self.image_dim, self.image_dim, 1)
        vonmises_weight2 = vonmises_weight2.permute(2, 0, 1)

        # Apply softmax to the weights (so the two mixture weights sum to 1)
        vonmises_weights = torch.stack([vonmises_weight1, vonmises_weight2], dim=0)
        vonmises_weights = torch.nn.functional.softmax(vonmises_weights, dim=0)
        vonmises_weight1 = vonmises_weights[0]
        vonmises_weight2 = vonmises_weights[1]

        # calculation of von Mises densities
        # NOTE(review): same module-level `device` global as above.
        vonmises_density_layer1 = self.vonmises_density(self.bearing_layer, vonmises_kappa1, vonmises_mu1).to(device)
        vonmises_density_layer2 = self.vonmises_density(self.bearing_layer, vonmises_kappa2, vonmises_mu2).to(device)

        # combining both densities to create a mixture distribution using the logsumexp trick
        logsumexp_vm_corr = torch.max(vonmises_density_layer1, vonmises_density_layer2)
        vonmises_density_layer = logsumexp_vm_corr + torch.log(vonmises_weight1 * torch.exp(vonmises_density_layer1 - logsumexp_vm_corr) + vonmises_weight2 * torch.exp(vonmises_density_layer2 - logsumexp_vm_corr))
        # print(torch.sum(vonmises_density_layer))
        # print(torch.sum(torch.exp(vonmises_density_layer)))

        # normalise so the densities sum to 1 when exponentiated
        # vonmises_density_layer = vonmises_density_layer - torch.logsumexp(vonmises_density_layer, dim = (1, 2), keepdim = True)
        # print(torch.sum(vonmises_density_layer))
        # print(torch.sum(torch.exp(vonmises_density_layer)))

        # combining the two distributions
        movement_grid = gamma_density_layer + vonmises_density_layer # Gamma and von Mises densities are on the log-scale
        # print('Movement grid ', torch.sum(movement_grid))
        # print(torch.sum(torch.exp(movement_grid)))

        # normalise before combining with the habitat predictions
        movement_grid = movement_grid - torch.logsumexp(movement_grid, dim = (1, 2), keepdim = True)

        # print('Movement grid norm ', torch.sum(movement_grid))
        # print(torch.sum(torch.exp(movement_grid)))

        return movement_grid


class Scalar_to_Grid_Block(nn.Module):
    """Broadcast per-sample scalar covariates into constant-valued spatial grids.

    Each scalar becomes one channel of size (image_dim, image_dim) so it can be
    concatenated with the spatial covariate layers before the convolutions.
    """

    def __init__(self, params):
        super(Scalar_to_Grid_Block, self).__init__()
        self.batch_size = params.batch_size
        self.image_dim = params.image_dim
        self.device = params.device

    def forward(self, x):
        """x: (batch, n_scalars) -> (batch, n_scalars, image_dim, image_dim)."""
        batch, n_scalars = x.shape[0], x.shape[1]
        # view to (batch, n_scalars, 1, 1) then expand (no copy) across the grid
        grids = x.view(batch, n_scalars, 1, 1)
        return grids.expand(batch, n_scalars, self.image_dim, self.image_dim)


class Vector_to_Grid_Block(nn.Module):
    def __init__(self, params):
        super(Vector_to_Grid_Block, self).__init__()
        self.batch_size = params.batch_size
        self.image_dim = params.image_dim
        self.device = params.device

    def forward(self, x):
        x_unnorm = x.reshape(x.shape[0], self.image_dim, self.image_dim)
        x = x_unnorm - torch.logsumexp(x_unnorm, dim = (1, 2), keepdim = True)
        return x

        # x = x_unnorm/torch.sum(x_unnorm)
        # return x


class ConvJointModel(nn.Module):
    """Joint deepSSF model combining a habitat-selection branch and a movement branch.

    The forward pass takes a tuple of (spatial covariate stack, scalar covariates,
    previous-step bearing) and returns the two log-density surfaces stacked along
    a trailing axis: output[..., 0] is habitat selection, output[..., 1] is movement.
    """

    def __init__(self, params):
        super(ConvJointModel, self).__init__()
        # scalar covariates are broadcast to grids and appended as extra channels
        self.scalar_grid_output = Scalar_to_Grid_Block(params)

        # habitat-selection branch: fully convolutional over the stacked channels
        self.conv_habitat = Conv2d_block_spatial(params)

        # movement branch: conv features -> dense layers -> parametric kernel grid
        self.conv_movement = Conv2d_block_toFC(params)
        self.fcn_movement_all = FCN_block_all_movement(params)
        # kept for checkpoint compatibility; not used in forward()
        self.fcn_movement_nonspatial = FCN_block_nonspatial(params)
        self.movement_grid_output = Params_to_Grid_Block(params)
        self.device = params.device

    def forward(self, x):
        spatial_covs, scalar_covs, prev_bearing = x[0], x[1], x[2]

        # SCALAR GRIDS: append broadcast scalar layers to the spatial stack
        scalar_grids = self.scalar_grid_output(scalar_covs)
        stacked = torch.cat([spatial_covs, scalar_grids], dim = 1)

        # HABITAT SELECTION branch (normalised log-density grid)
        habitat_out = self.conv_habitat(stacked)

        # MOVEMENT branch: conv features -> kernel parameters -> density grid,
        # oriented relative to the previous step's bearing
        movement_features = self.conv_movement(stacked)
        movement_params = self.fcn_movement_all(movement_features)
        movement_out = self.movement_grid_output(movement_params, prev_bearing)

        # stack the two log-density surfaces along a new trailing axis
        return torch.stack((habitat_out, movement_out), dim = -1)


class ModelParams():
    """Plain container that maps a parameter dictionary onto named attributes.

    The original implementation assigned ``batch_size`` four times and
    ``image_dim``, ``dense_dim_hidden`` and ``dense_dim_out`` twice each;
    the duplicates are removed here. The attribute set is unchanged, so all
    downstream blocks that read these attributes behave identically.
    """

    def __init__(self, dict_params):
        # general / data dimensions
        self.batch_size = dict_params["batch_size"]
        self.image_dim = dict_params["image_dim"]
        self.pixel_size = dict_params["pixel_size"]
        # fully connected (dense) block dimensions
        self.dim_in_nonspatial_to_grid = dict_params["dim_in_nonspatial_to_grid"]
        self.dense_dim_in_nonspatial = dict_params["dense_dim_in_nonspatial"]
        self.dense_dim_hidden = dict_params["dense_dim_hidden"]
        self.dense_dim_out = dict_params["dense_dim_out"]
        self.dense_dim_in_all = dict_params["dense_dim_in_all"]
        # convolutional block dimensions
        self.input_channels = dict_params["input_channels"]
        self.output_channels = dict_params["output_channels"]
        self.kernel_size = dict_params["kernel_size"]
        self.stride = dict_params["stride"]
        self.kernel_size_mp = dict_params["kernel_size_mp"]
        self.stride_mp = dict_params["stride_mp"]
        self.padding = dict_params["padding"]
        # movement kernel and training settings
        self.num_movement_params = dict_params["num_movement_params"]
        self.dropout = dict_params["dropout"]
        self.device = dict_params["device"]

Instantiate the model

Code
# Hyperparameters and architecture settings for the deepSSF model.
# NOTE(review): "input_channels": 4 + 4 must stay in sync with the number of
# spatial layers (4) plus "dim_in_nonspatial_to_grid" (4) — confirm if either changes.
params_dict = {"batch_size": 32,
               "image_dim": 101, #number of pixels along the edge of each local patch/image
               "pixel_size": 25, #number of metres along the edge of a pixel
               "dim_in_nonspatial_to_grid": 4, #the number of scalar predictors that are converted to a grid and appended to the spatial features
               "dense_dim_in_nonspatial": 4, #change this to however many other scalar predictors you have (bearing, velocity etc)
               "dense_dim_hidden": 128, #number of nodes in the hidden layers
               "dense_dim_out": 128, #number of nodes in the output of the fully connected block (FCN)
               "dense_dim_in_all": 2500,# + 128, #number of inputs entering the fully connected block once the nonspatial features have been concatenated to the spatial features
               "input_channels": 4 + 4, #number of spatial layers in each image + number of scalar layers that are converted to a grid
               "output_channels": 4, #number of filters to learn
               "kernel_size": 3, #the size of the 2D moving windows / kernels that are being learned
               "stride": 1, #the stride used when applying the kernel.  This reduces the dimension of the output if set to greater than 1
               "kernel_size_mp": 2, #the size of the kernel that is used in max pooling operations
               "stride_mp": 2, #the stride that is used in max pooling operations
               "padding": 1, #the amount of padding to apply to images prior to applying the 2D convolution
               "num_movement_params": 12, #number of parameters used to parameterise the movement kernel
               "dropout": 0.1,
               "device": device
               }

# Build the parameter object and instantiate the model on the configured device
params = ModelParams(params_dict)
model = ConvJointModel(params).to(device)
# print(model)

Load the model weights

Code
# date of the trained model checkpoint
date = '2024-11-12'

# Load the trained weights for this individual.
# weights_only=True restricts torch.load to tensor/state-dict data and avoids
# unpickling arbitrary objects — this addresses the FutureWarning emitted by
# the original call and is safe for a plain state_dict checkpoint.
model.load_state_dict(
    torch.load(
        f'model_checkpoints/checkpoint_CNN_global_buffalo{buffalo_id}_TAmix_bearing-rev_adj_hab-covs_grid-only_{date}.pt',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
# model.eval()  # NOTE(review): the model uses dropout; eval() should likely be
# enabled for validation so dropout is disabled — confirm before re-running.
C:\Users\for329\AppData\Local\Temp\ipykernel_22308\338321716.py:6: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  model.load_state_dict(torch.load(f'model_checkpoints/checkpoint_CNN_global_buffalo{buffalo_id}_TAmix_bearing-rev_adj_hab-covs_grid-only_{date}.pt', map_location=torch.device('cpu')))
<All keys matched successfully>

Setup validation

Code
# Create a mapping from day of the year to month index
def day_to_month_index(day_of_year):
    """Convert a 1-based day-of-year (base year 2018, may exceed 365 to roll
    into the following year) to a 0-based month index counted from Jan 2018.

    e.g. day 1 -> 0 (Jan 2018), day 280 -> 9 (Oct 2018), day 370 -> 12 (Jan 2019).
    """
    base_date = datetime(2018, 1, 1)
    # FIX: the original used timedelta, which is not visible in this script's
    # imports (only `from datetime import datetime`); ordinal arithmetic gives
    # the same date without needing timedelta at all.
    date = datetime.fromordinal(base_date.toordinal() + int(day_of_year) - 1)
    year_diff = date.year - base_date.year
    month_index = (date.month - 1) + (year_diff * 12)  # month index (0-based, accounting for year change)
    return month_index

yday = 370 # NOT THE MONTH (due to 0 indexing) - a day-of-year beyond 365, used to check the monthly NDVI index rolls into the next year
month_index = day_to_month_index(yday)
print(month_index)
12
Code
window_size = 101

# starting location of buffalo 2005
x = 41969.310875
y = -1.435671e+06

yday = 280
month_index = day_to_month_index(yday)

# extract the local (window_size x window_size) covariate patches around (x, y);
# NDVI is monthly so it is indexed by month_index first
ndvi_subset, origin_x, origin_y = subset_raster_with_padding_torch(ndvi_global_norm[month_index,:,:], x, y, window_size, raster_transform)
canopy_subset, origin_x, origin_y = subset_raster_with_padding_torch(canopy_global_norm, x, y, window_size, raster_transform)
herby_subset, origin_x, origin_y = subset_raster_with_padding_torch(herby_global_norm, x, y, window_size, raster_transform)
slope_subset, origin_x, origin_y = subset_raster_with_padding_torch(slope_global_norm, x, y, window_size, raster_transform)

print(origin_x, origin_y)

# Plot the four local covariate layers on a 2x2 grid
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

panels = ((ndvi_subset, 'NDVI'),
          (canopy_subset, 'Canopy Cover'),
          (herby_subset, 'Herbaceous Vegetation'),
          (slope_subset, 'Slope'))

for ax, (layer, label) in zip(axs.flat, panels):
    ax.imshow(layer.numpy(), cmap='viridis')
    ax.set_title(label)
1628 1136
Text(0.5, 1.0, 'Slope')

Next-step probability values

Code
# test_data = buffalo_df.iloc[0:10]
test_data = buffalo_df
n = len(test_data)

# pre-allocate float arrays (zeros) for the per-step predicted probabilities
habitat_probs = np.zeros(len(buffalo_df))
move_probs = np.zeros(len(buffalo_df))
next_step_probs = np.zeros(len(buffalo_df))

# Validation loop: for every observed step, build the model inputs from the
# local covariates at the current location, run the model, and record the
# habitat, movement and combined next-step probabilities at the observed
# next-step cell. Results are written into the pre-allocated arrays above.
# start at 1 so the bearing at t - 1 is available
for i in range(1, n):

  sample_tm1 = test_data.iloc[i-1] # get the step at t - 1 for the bearing of the approaching step
  sample = test_data.iloc[i]

  # current location (x1, y1)
  x = sample['x1_']
  y = sample['y1_']
  # print('x and y are ', x,y)

  # Convert geographic coordinates to pixel coordinates
  px, py = ~raster_transform * (x, y)
  # Convert geographic coordinates to pixel coordinates
  # new_x, new_y = raster_transform * (px, py)
  # print('px and py are ', px, py)

  # next step location (x2, y2)
  x2 = sample['x2_']
  y2 = sample['y2_']
  # print('x2 and y2 are ', x2, y2)

  # Convert geographic coordinates to pixel coordinates
  px2, py2 = ~raster_transform * (x2, y2)
  # print('px2 and py2 are ', px2, py2)


  # # THE Y COORDINATE COMES FIRST in the sampled coordinates
  # new_px = origin_xs[0] + sampled_coordinates[1]
  # new_py = origin_ys[0] + sampled_coordinates[0]

  # # Convert geographic coordinates to pixel coordinates
  # new_x, new_y = raster_transform * (new_px, new_py)

  # displacement to the next step (computed but not used below)
  d_x = x2 - x
  d_y = y2 - y
  # print('d_x and d_y are ', d_x, d_y)
  

  # temporal covariates
  hour_t2_sin = sample['hour_t2_sin']
  hour_t2_cos = sample['hour_t2_cos']
  yday_t2_sin = sample['yday_t2_sin']
  yday_t2_cos = sample['yday_t2_cos']
  # print(hour_t2_sin, hour_t2_cos, yday_t2_sin, yday_t2_cos)

  # bearing of PREVIOUS step
  bearing = sample_tm1['bearing']
  # print(bearing)

  # yday = sample['yday_t2_base_2018']
  yday = sample['yday_t2']
  # print(yday)

  # month index into the monthly NDVI stack (see day_to_month_index above)
  month_index = day_to_month_index(yday)
  # print(month_index)

  # local covariate patches centred on the current location
  ndvi_subset, origin_x, origin_y = subset_raster_with_padding_torch(ndvi_global_norm[month_index,:,:], x, y, window_size, raster_transform)
  canopy_subset, origin_x, origin_y = subset_raster_with_padding_torch(canopy_global_norm, x, y, window_size, raster_transform)
  herby_subset, origin_x, origin_y = subset_raster_with_padding_torch(herby_global_norm, x, y, window_size, raster_transform)
  slope_subset, origin_x, origin_y = subset_raster_with_padding_torch(slope_global_norm, x, y, window_size, raster_transform)

  # print('origin_x and origin_y are ', origin_x, origin_y)

  # next-step location expressed in the local patch's pixel coordinates
  px2_subset = px2 - origin_x
  py2_subset = py2 - origin_y
  # print('delta origin_x and origin_y are ', px2_subset, py2_subset)
  # print('delta origin_x and origin_y are ', int(px2_subset), int(py2_subset))

  # Extract the value of the covariates at the location of x2, y2
  value = ndvi_subset.detach().cpu().numpy()[(int(py2_subset), int(px2_subset))]

  # # from the stack of local layers
  # ndvi_max = 0.8220
  # ndvi_min = -0.2772
  # # return to natural scale
  # value_natural = value * (ndvi_max - ndvi_min) + ndvi_min

  # print('Value natural scale = ', value_natural)
  

  # # Plot the subset covariates
  # fig_covs, axs_covs = plt.subplots(2, 2, figsize=(10, 10))

  # # NDVI
  # axs_covs[0, 0].imshow(ndvi_subset.numpy(), cmap='viridis')
  # axs_covs[0, 0].set_title('NDVI')
  # fig_covs.colorbar(axs_covs[0, 0].imshow(ndvi_subset.numpy(), cmap='viridis'), ax=axs_covs[0, 0], shrink=0.7)
  # # Canopy Cover
  # axs_covs[0, 1].imshow(canopy_subset.numpy(), cmap='viridis')
  # axs_covs[0, 1].set_title('Canopy Cover')
  # fig_covs.colorbar(axs_covs[0, 1].imshow(canopy_subset.numpy(), cmap='viridis'), ax=axs_covs[0, 1], shrink=0.7)
  # # Herbaceous Vegetation
  # axs_covs[1, 0].imshow(herby_subset.numpy(), cmap='viridis')
  # axs_covs[1, 0].set_title('Herbaceous Vegetation')
  # fig_covs.colorbar(axs_covs[1, 0].imshow(herby_subset.numpy(), cmap='viridis'), ax=axs_covs[1, 0], shrink=0.7)
  # # Slope
  # axs_covs[1, 1].imshow(slope_subset.numpy(), cmap='viridis')
  # axs_covs[1, 1].set_title('Slope')  
  # fig_covs.colorbar(axs_covs[1, 1].imshow(slope_subset.numpy(), cmap='viridis'), ax=axs_covs[1, 1], shrink=0.7)


  # Stack the channels along a new axis; here, 1 is commonly used for channel axis in PyTorch
  x1 = torch.stack([ndvi_subset, canopy_subset, herby_subset, slope_subset], dim=0)
  x1 = x1.unsqueeze(0)
  # print(x1.shape)

  # Convert lists to PyTorch tensors
  hour_t2_sin_tensor = torch.tensor(hour_t2_sin).float()
  hour_t2_cos_tensor = torch.tensor(hour_t2_cos).float()
  yday_t2_sin_tensor = torch.tensor(yday_t2_sin).float()
  yday_t2_cos_tensor = torch.tensor(yday_t2_cos).float()

  # Stack tensors column-wise
  # NOTE: x2 is deliberately re-bound here as the scalar-covariate tensor;
  # the geographic x2 is no longer needed (px2/py2 were computed above).
  x2 = torch.stack((hour_t2_sin_tensor.unsqueeze(0), 
                    hour_t2_cos_tensor.unsqueeze(0), 
                    yday_t2_sin_tensor.unsqueeze(0), 
                    yday_t2_cos_tensor.unsqueeze(0)),  
                    dim=1)
  # print(x2)
  # print(x2.shape)

  # put bearing in the correct dimension (batch_size, 1)
  bearing = torch.tensor(bearing).float().unsqueeze(0).unsqueeze(0)
  # print(bearing)
  # print(bearing.shape)

  # forward pass of the joint model.
  # NOTE(review): consider wrapping in torch.no_grad() (and model.eval()) for
  # validation — gradients are not needed here and dropout may be active.
  test = model((x1, x2, bearing))

  # plot the results of the habitat density as an image
  hab_density = test.detach().cpu().numpy()[0,:,:,0]
  hab_density_exp = np.exp(hab_density)
  # print(np.sum(hab_density_exp))

  # Store the probability of habitat selection at the location of x2, y2
  habitat_probs[i] = hab_density_exp[(int(py2_subset), int(px2_subset))]
  # print('Habitat probability value = ', habitat_probs[i])

  # Create the mask for x and y coordinates
  x_mask = np.ones_like(hab_density)
  y_mask = np.ones_like(hab_density)

  # mask out cells on the edges that affect the colour scale
  # NOTE(review): multiplying by -inf sends edge cells to +/-inf depending on the
  # sign of the density — these masked arrays are only used by the (currently
  # commented-out) plotting below, not by the stored probabilities.
  x_mask[:, :3] = -np.inf
  x_mask[:, 98:] = -np.inf
  y_mask[:3, :] = -np.inf
  y_mask[98:, :] = -np.inf

  hab_density_mask = hab_density * x_mask * y_mask
  hab_density_exp_mask = hab_density_exp * x_mask * y_mask

  # # plot the results of the habitat density as an image - in log scale
  # plt.imshow(hab_density_mask)
  # plt.colorbar()
  # plt.show()

  # # plot the results of the habitat density as an image - as probabilities
  # plt.imshow(hab_density_exp_mask)
  # plt.colorbar()
  # plt.show()

  # movement probability
  move_density = test.detach().cpu().numpy()[0,:,:,1]
  move_density_exp = np.exp(move_density)
  # print(np.sum(move_density_exp))

  # Store the probability of movement at the location of x2, y2
  move_probs[i] = move_density_exp[(int(py2_subset), int(px2_subset))]
  # print('Movement probability value = ', move_probs[i])

  move_density_mask = move_density * x_mask * y_mask
  move_density_exp_mask = move_density_exp * x_mask * y_mask

  # # plot the results of the movement density as an image - in log scale
  # plt.imshow(move_density_mask)
  # plt.colorbar()
  # plt.show()

  # # plot the results of the movement density as an image - as probabilities
  # plt.imshow(move_density_exp_mask)
  # plt.colorbar()
  # plt.show()

  # next step probability: sum the two log-densities (product of probabilities)
  step_density = test[0, :, :, 0] + test[0, :, :, 1]
  step_density = step_density.detach().cpu().numpy()
  step_density_exp = np.exp(step_density)
  # print('Sum of step density exp = ', np.sum(step_density_exp))

  # re-normalise so the combined surface sums to one
  step_density_exp_norm = step_density_exp / np.sum(step_density_exp)
  # print('Sum of step density exp norm = ', np.sum(step_density_exp_norm))

  # Extract the value of the covariates at the location of x2, y2
  next_step_probs[i] = step_density_exp_norm[(int(py2_subset), int(px2_subset))]
  # print('Next-step probability value = ', next_step_probs[i])

  step_density_mask = step_density * x_mask * y_mask
  step_density_exp_norm_mask = step_density_exp_norm * x_mask * y_mask

  # # results of the habitat and movement densities
  # # log-scale
  # plt.imshow(step_density_mask)
  # plt.colorbar()
  # plt.show()

  # # exponentiated
  # plt.imshow(step_density_exp_mask)
  # plt.colorbar()
  # plt.show()


  # # Plot the outputs
  # fig_out, axs_out = plt.subplots(2, 2, figsize=(10, 10))

  # # Plot NDVI
  # im1 = axs_out[0, 0].imshow(ndvi_subset.numpy(), cmap='viridis')
  # axs_out[0, 0].set_title('NDVI')
  # fig_out.colorbar(im1, ax=axs_out[0, 0], shrink=0.7)

  # # # Plot target
  # # im2 = axs[0, 1].imshow(labels.detach().cpu().numpy()[0,:,:], cmap='viridis')
  # # axs[0, 1].set_title('Observed next step')
  # # # fig.colorbar(im2, ax=axs[0, 1], shrink=0.7)

  # # # Plot slope
  # # im2 = axs[0, 1].imshow(x1.detach().cpu().numpy()[0,12,:,:], cmap='viridis')
  # # axs[0, 1].set_title('Slope')
  # # fig.colorbar(im2, ax=axs[0, 1], shrink=0.7)

  # # Plot habitat selection log-probability
  # im2 = axs_out[0, 1].imshow(hab_density_mask, cmap='viridis')
  # axs_out[0, 1].set_title('Habitat selection log-probability')
  # fig_out.colorbar(im2, ax=axs_out[0, 1], shrink=0.7)

  # # Movement density log-probability
  # im3 = axs_out[1, 0].imshow(move_density_mask, cmap='viridis')
  # axs_out[1, 0].set_title('Movement log-probability')
  # fig_out.colorbar(im3, ax=axs_out[1, 0], shrink=0.7)

  # # # Next-step log probability
  # # im4 = axs[1, 1].imshow(step_density_mask, cmap='viridis')
  # # axs[1, 1].set_title('Next-step log-probability')
  # # fig.colorbar(im4, ax=axs[1, 1], shrink=0.7)

  # # [(int(py2_subset), int(px2_subset))]
  # next_step_mask = np.ones_like(hab_density)
  # next_step_mask[int(py2_subset), int(px2_subset)] = -np.inf

  # # Next-step probability
  # im4 = axs_out[1, 1].imshow(step_density_exp_norm_mask * next_step_mask, cmap='viridis')
  # axs_out[1, 1].set_title('Next-step probability')
  # fig_out.colorbar(im4, ax=axs_out[1, 1], shrink=0.7)

  # # filename_covs = f'outputs/next_step_validation/id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}_bearing{bearing_degrees}_next_r{row}_c{column}.png'
  # # plt.tight_layout()
  # # plt.savefig(filename_covs, dpi=600, bbox_inches='tight')
  # # plt.show()
  # # plt.close()  # Close the figure to free memory
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[23], line 137
    133   bearing = torch.tensor(bearing).float().unsqueeze(0).unsqueeze(0)
    134   # print(bearing)
    135   # print(bearing.shape)
--> 137   test = model((x1, x2, bearing))
    139 # plot the results of the habitat density as an image
    140   hab_density = test.detach().cpu().numpy()[0,:,:,0]

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

Cell In[18], line 370, in ConvJointModel.forward(self, x)
    365 all_spatial = torch.cat([spatial_data_x, scalar_grids], dim = 1)
    366 # print(f"Shape after scalar grid: {all_spatial.shape}")  # Debugging print
    367 
    368 
    369 # HABITAT SELECTION
--> 370 output_habitat = self.conv_habitat(all_spatial)
    371 # print(f"Shape after CNN habitat: {output_habitat.shape}")  # Debugging print
    372 
    373 
    374 # MOVEMENT
    375 conv_movement = self.conv_movement(all_spatial)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

Cell In[18], line 52, in Conv2d_block_spatial.forward(self, x)
     50 def forward(self, x):
     51     # print("Shape before squeeze:", self.conv2d(x).shape)
---> 52     conv2d_spatial = self.conv2d(x).squeeze(dim = 1)
     53     # print("Shape before logsumexp:", conv2d_spatial.shape)
     54 
     55     # normalise before combining with the movement grid
     56     conv2d_spatial = conv2d_spatial - torch.logsumexp(conv2d_spatial, dim = (1, 2), keepdim = True)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\container.py:250, in Sequential.forward(self, input)
    248 def forward(self, input):
    249     for module in self:
--> 250         input = module(input)
    251     return input

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\modules\activation.py:133, in ReLU.forward(self, input)
    132 def forward(self, input: Tensor) -> Tensor:
--> 133     return F.relu(input, inplace=self.inplace)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\torch\nn\functional.py:1704, in relu(input, inplace)
   1702     result = torch.relu_(input)
   1703 else:
-> 1704     result = torch.relu(input)
   1705 return result

KeyboardInterrupt: 
Code
print(next_step_probs)
[0.         0.0030911  0.00460396 ... 0.         0.         0.        ]
Code
# Plot each probability series through time as a line graph (log scale,
# skipping zero entries which have no prediction)
for series, plot_title in ((habitat_probs, 'Habitat Probability'),
                           (move_probs, 'Movement Probability'),
                           (next_step_probs, 'Next-step Probability')):
    plt.plot(np.log(series[series > 0]))
    plt.xlabel('Time')
    plt.ylabel('Probability')
    plt.title(plot_title)
    plt.show()

Code
# Append the probabilities to the dataframe (in-place column assignment)
for prob_col, prob_values in (('habitat_probs', habitat_probs),
                              ('move_probs', move_probs),
                              ('next_step_probs', next_step_probs)):
    buffalo_df[prob_col] = prob_values
Code
# Output path for this individual's per-step probabilities
csv_filename = (
    f'../outputs/next_step_probs_TAmix_id{buffalo_id}_{today_date}.csv'
)

print(csv_filename)
# buffalo_df.to_csv(csv_filename, index=True)
../outputs/next_step_probs_TAmix_id2387_2024-12-06.csv

Loop the validation over each individual

Code
# select the id to train the model on
# NOTE(review): buffalo_id and n_samples must already be defined when this cell
# runs (one of the pairs below, set elsewhere) — confirm before executing.
# buffalo_id = 2005
# n_samples = 10297

# buffalo_id = 2014
# n_samples = 6572

# buffalo_id = 2327
# n_samples = 8983

# buffalo_id = 2387
# n_samples = 10409

# Specify the path to your CSV file
csv_file_path = f'../buffalo_local_data_id/validation/validation_buffalo_{buffalo_id}_data_df_lag_1hr_n{n_samples}.csv'

# Read the CSV file into a DataFrame
buffalo_df = pd.read_csv(csv_file_path)

# Display the first few rows of the DataFrame
print(buffalo_df.head())

# Lag the bearing by one step so each row carries the previous step's bearing
buffalo_df['bearing_tm1'] = buffalo_df['bearing'].shift(1)
# Pad the missing value with a specified value, e.g., 0
buffalo_df['bearing_tm1'] = buffalo_df['bearing_tm1'].fillna(0)

# Display the first few rows of the DataFrame
print(buffalo_df.head())
             x_            y_                    t_    id           x1_  \
0  41969.310875 -1.435671e+06  2018-07-25T01:04:23Z  2005  41969.310875   
1  41921.521939 -1.435654e+06  2018-07-25T02:04:39Z  2005  41921.521939   
2  41779.439594 -1.435601e+06  2018-07-25T03:04:17Z  2005  41779.439594   
3  41841.203272 -1.435635e+06  2018-07-25T04:04:39Z  2005  41841.203272   
4  41655.463332 -1.435604e+06  2018-07-25T05:04:27Z  2005  41655.463332   

            y1_           x2_           y2_     x2_cent    y2_cent  ...  \
0 -1.435671e+06  41921.521939 -1.435654e+06  -47.788936  16.857110  ...   
1 -1.435654e+06  41779.439594 -1.435601e+06 -142.082345  53.568427  ...   
2 -1.435601e+06  41841.203272 -1.435635e+06   61.763677 -34.322938  ...   
3 -1.435635e+06  41655.463332 -1.435604e+06 -185.739939  31.003534  ...   
4 -1.435604e+06  41618.651923 -1.435608e+06  -36.811409  -4.438037  ...   

         ta    cos_ta         x_min         x_max         y_min         y_max  \
0  1.367942  0.201466  40706.810875  43231.810875 -1.436934e+06 -1.434409e+06   
1 -0.021429  0.999770  40659.021939  43184.021939 -1.436917e+06 -1.434392e+06   
2  2.994917 -0.989262  40516.939594  43041.939594 -1.436863e+06 -1.434338e+06   
3 -2.799767 -0.942144  40578.703272  43103.703272 -1.436898e+06 -1.434373e+06   
4  0.285377  0.959556  40392.963332  42917.963332 -1.436867e+06 -1.434342e+06   

   s2_index  points_vect_cent  year_t2  yday_t2_2018_base  
0         7               NaN     2018                206  
1         7               NaN     2018                206  
2         7               NaN     2018                206  
3         7               NaN     2018                206  
4         7               NaN     2018                206  

[5 rows x 35 columns]
             x_            y_                    t_    id           x1_  \
0  41969.310875 -1.435671e+06  2018-07-25T01:04:23Z  2005  41969.310875   
1  41921.521939 -1.435654e+06  2018-07-25T02:04:39Z  2005  41921.521939   
2  41779.439594 -1.435601e+06  2018-07-25T03:04:17Z  2005  41779.439594   
3  41841.203272 -1.435635e+06  2018-07-25T04:04:39Z  2005  41841.203272   
4  41655.463332 -1.435604e+06  2018-07-25T05:04:27Z  2005  41655.463332   

            y1_           x2_           y2_     x2_cent    y2_cent  ...  \
0 -1.435671e+06  41921.521939 -1.435654e+06  -47.788936  16.857110  ...   
1 -1.435654e+06  41779.439594 -1.435601e+06 -142.082345  53.568427  ...   
2 -1.435601e+06  41841.203272 -1.435635e+06   61.763677 -34.322938  ...   
3 -1.435635e+06  41655.463332 -1.435604e+06 -185.739939  31.003534  ...   
4 -1.435604e+06  41618.651923 -1.435608e+06  -36.811409  -4.438037  ...   

     cos_ta         x_min         x_max         y_min         y_max  s2_index  \
0  0.201466  40706.810875  43231.810875 -1.436934e+06 -1.434409e+06         7   
1  0.999770  40659.021939  43184.021939 -1.436917e+06 -1.434392e+06         7   
2 -0.989262  40516.939594  43041.939594 -1.436863e+06 -1.434338e+06         7   
3 -0.942144  40578.703272  43103.703272 -1.436898e+06 -1.434373e+06         7   
4  0.959556  40392.963332  42917.963332 -1.436867e+06 -1.434342e+06         7   

   points_vect_cent  year_t2  yday_t2_2018_base  bearing_tm1  
0               NaN     2018                206     0.000000  
1               NaN     2018                206     2.802478  
2               NaN     2018                206     2.781049  
3               NaN     2018                206    -0.507220  
4               NaN     2018                206     2.976198  

[5 rows x 36 columns]
Code
# Step 1: Specify the directory containing the validation CSV files
# (the previous comments said "TIFF files", which was misleading -- these
# are CSVs of observed buffalo steps).
data_dir = '../buffalo_local_data_id/validation'

# Step 2: Collect all matching CSV files. glob.glob does not guarantee any
# particular ordering (it follows the OS directory listing), so sort
# explicitly to make the selection below deterministic.
csv_files = sorted(glob.glob(os.path.join(data_dir, 'validation_*.csv')))
print(f'Found {len(csv_files)} CSV files')
print('\n'.join(csv_files))

# Select only buffalo id 2005. Filter by the id embedded in the filename
# rather than assuming it is the first list element, which depended on the
# (unspecified) ordering of the glob result.
if csv_files:
    csv_files = [f for f in csv_files if 'buffalo_2005' in os.path.basename(f)] or csv_files[:1]
print(csv_files)
Found 13 CSV files
../buffalo_local_data_id/validation\validation_buffalo_2005_data_df_lag_1hr_n10297.csv
../buffalo_local_data_id/validation\validation_buffalo_2014_data_df_lag_1hr_n6572.csv
../buffalo_local_data_id/validation\validation_buffalo_2018_data_df_lag_1hr_n9440.csv
../buffalo_local_data_id/validation\validation_buffalo_2021_data_df_lag_1hr_n6928.csv
../buffalo_local_data_id/validation\validation_buffalo_2022_data_df_lag_1hr_n9099.csv
../buffalo_local_data_id/validation\validation_buffalo_2024_data_df_lag_1hr_n9531.csv
../buffalo_local_data_id/validation\validation_buffalo_2039_data_df_lag_1hr_n5569.csv
../buffalo_local_data_id/validation\validation_buffalo_2154_data_df_lag_1hr_n10417.csv
../buffalo_local_data_id/validation\validation_buffalo_2158_data_df_lag_1hr_n9700.csv
../buffalo_local_data_id/validation\validation_buffalo_2223_data_df_lag_1hr_n5310.csv
../buffalo_local_data_id/validation\validation_buffalo_2327_data_df_lag_1hr_n8983.csv
../buffalo_local_data_id/validation\validation_buffalo_2387_data_df_lag_1hr_n10409.csv
../buffalo_local_data_id/validation\validation_buffalo_2393_data_df_lag_1hr_n5299.csv
['../buffalo_local_data_id/validation\\validation_buffalo_2005_data_df_lag_1hr_n10297.csv']
Code
import re  # for extracting the buffalo id from the filename

# For each individual's validation CSV: run the deepSSF model at every step in
# a window of the observed trajectory and record the predicted probability of
# the observed next step under the habitat-selection, movement, and combined
# next-step probability surfaces.
for j in range(len(csv_files)):

    # Read the CSV file for this individual
    buffalo_df = pd.read_csv(csv_files[j])

    # Extract the buffalo id from the file NAME (e.g. 'validation_buffalo_2005_...')
    # instead of a hard-coded character slice of the full path, which breaks
    # silently whenever the directory prefix changes length.
    id_match = re.search(r'buffalo_(\d+)_', os.path.basename(csv_files[j]))
    buffalo_id = id_match.group(1) if id_match else csv_files[j][55:59]

    # Lag the bearing by one step so each row also carries the bearing of the
    # step that arrived at its start location; pad the first value with 0.
    buffalo_df['bearing_tm1'] = buffalo_df['bearing'].shift(1).fillna(0)

    # Subset of steps to validate on (swap in the full dataframe to validate everything)
    test_data = buffalo_df.iloc[7000:7500]
    # test_data = buffalo_df
    n = len(test_data)

    # Vectors (full dataframe length) to store the predicted probabilities;
    # rows that are not validated keep 0.
    habitat_probs = np.zeros(len(buffalo_df))
    move_probs = np.zeros(len(buffalo_df))
    next_step_probs = np.zeros(len(buffalo_df))

    # start at 1 so the bearing at t - 1 is available
    for i in range(1, n):

        # NOTE(review): store results at the row's position in the FULL
        # dataframe, not at the loop index -- previously the probabilities for
        # rows 7000:7500 were written to indices 1..499, misaligning them with
        # the dataframe they are appended to below.
        row_idx = test_data.index[i]

        sample_tm1 = test_data.iloc[i - 1]  # step at t - 1 (bearing of the approaching step)
        sample = test_data.iloc[i]

        # Current location (x1, y1) in geographic coordinates
        x = sample['x1_']
        y = sample['y1_']

        # Convert geographic coordinates to pixel coordinates
        px, py = ~raster_transform * (x, y)

        # Next-step location (x2, y2), also converted to pixel coordinates
        x2 = sample['x2_']
        y2 = sample['y2_']
        px2, py2 = ~raster_transform * (x2, y2)

        # Temporal covariates (sin/cos encodings of hour and day-of-year)
        hour_t2_sin = sample['hour_t2_sin']
        hour_t2_cos = sample['hour_t2_cos']
        yday_t2_sin = sample['yday_t2_sin']
        yday_t2_cos = sample['yday_t2_cos']
        hour_t2_integer = int(sample['hour_t2'])

        # Bearing of the PREVIOUS step
        bearing = sample_tm1['bearing']

        yday = sample['yday_t2_2018_base']
        # yday = sample['yday_t2']
        yday_t2_integer = int(yday)

        # Month index for selecting the matching NDVI layer
        month_index = day_to_month_index(yday)

        # Local covariate subsets centred on the current location
        ndvi_subset, origin_x, origin_y = subset_raster_with_padding_torch(ndvi_global_norm[month_index, :, :], x, y, window_size, raster_transform)
        canopy_subset, origin_x, origin_y = subset_raster_with_padding_torch(canopy_global_norm, x, y, window_size, raster_transform)
        herby_subset, origin_x, origin_y = subset_raster_with_padding_torch(herby_global_norm, x, y, window_size, raster_transform)
        slope_subset, origin_x, origin_y = subset_raster_with_padding_torch(slope_global_norm, x, y, window_size, raster_transform)

        # Pixel coordinates of the observed next step WITHIN the local window;
        # NumPy indexing is (row = y, col = x), hence the (py, px) order.
        px2_subset = px2 - origin_x
        py2_subset = py2 - origin_y
        next_cell = (int(py2_subset), int(px2_subset))

        # Stack the covariate channels -> (1, 4, H, W) batch for the model
        x1 = torch.stack([ndvi_subset, canopy_subset, herby_subset, slope_subset], dim=0).unsqueeze(0)

        # Temporal covariates as a (1, 4) tensor
        x2_tensor = torch.stack((torch.tensor(hour_t2_sin).float().unsqueeze(0),
                                 torch.tensor(hour_t2_cos).float().unsqueeze(0),
                                 torch.tensor(yday_t2_sin).float().unsqueeze(0),
                                 torch.tensor(yday_t2_cos).float().unsqueeze(0)),
                                dim=1)

        # Bearing in the shape the model expects: (batch_size, 1)
        bearing = torch.tensor(bearing).float().unsqueeze(0).unsqueeze(0)

        # Forward pass; channel 0 = habitat log-density, channel 1 = movement
        # log-density (per the indexing below).
        test = model((x1, x2_tensor, bearing))

        # Habitat selection density (log scale) and as probabilities
        hab_density = test.detach().cpu().numpy()[0, :, :, 0]
        hab_density_exp = np.exp(hab_density)

        # Probability of habitat selection at the observed next step
        habitat_probs[row_idx] = hab_density_exp[next_cell]

        # Masks that blank the window edges, which otherwise distort the
        # colour scale of the plots.
        # NOTE(review): the 3/98 bounds assume a 101x101 window -- confirm
        # against window_size.
        x_mask = np.ones_like(hab_density)
        y_mask = np.ones_like(hab_density)
        x_mask[:, :3] = -np.inf
        x_mask[:, 98:] = -np.inf
        y_mask[:3, :] = -np.inf
        y_mask[98:, :] = -np.inf

        # Mask marking the observed next step on the plots.
        # NOTE(review): previously this was defined AFTER its first use, so the
        # first iteration raised a NameError (or, in a live session, silently
        # reused a stale mask from the previous iteration). Define it before use.
        next_step_mask = np.ones_like(hab_density)
        next_step_mask[next_cell] = -np.inf

        hab_density_mask = hab_density * x_mask * y_mask
        hab_density_exp_mask = hab_density_exp * x_mask * y_mask

        # Movement density (log scale) and as probabilities
        move_density = test.detach().cpu().numpy()[0, :, :, 1]
        move_density_exp = np.exp(move_density)

        # Probability of movement at the observed next step
        move_probs[row_idx] = move_density_exp[next_cell]

        move_density_mask = move_density * x_mask * y_mask
        move_density_exp_mask = move_density_exp * x_mask * y_mask

        # Next-step density: sum of habitat and movement log-densities,
        # exponentiated and normalised so the surface sums to one.
        step_density = (test[0, :, :, 0] + test[0, :, :, 1]).detach().cpu().numpy()
        step_density_exp = np.exp(step_density)
        step_density_exp_norm = step_density_exp / np.sum(step_density_exp)

        # Probability of the observed next step under the combined surface
        next_step_probs[row_idx] = step_density_exp_norm[next_cell]

        step_density_mask = step_density * x_mask * y_mask
        step_density_exp_norm_mask = step_density_exp_norm * x_mask * y_mask

        # Plot the NDVI covariate alongside the three prediction surfaces
        fig_out, axs_out = plt.subplots(2, 2, figsize=(10, 10))

        im1 = axs_out[0, 0].imshow(ndvi_subset.numpy(), cmap='viridis')
        axs_out[0, 0].set_title('NDVI')
        fig_out.colorbar(im1, ax=axs_out[0, 0], shrink=0.7)

        # Habitat selection probability (title previously mislabelled this
        # exponentiated surface as "log-probability")
        im2 = axs_out[0, 1].imshow(hab_density_exp_mask * next_step_mask, cmap='viridis')
        axs_out[0, 1].set_title('Habitat selection probability')
        fig_out.colorbar(im2, ax=axs_out[0, 1], shrink=0.7)

        # Movement density log-probability
        im3 = axs_out[1, 0].imshow(move_density_mask * next_step_mask, cmap='viridis')
        axs_out[1, 0].set_title('Movement log-probability')
        fig_out.colorbar(im3, ax=axs_out[1, 0], shrink=0.7)

        # Next-step probability
        im4 = axs_out[1, 1].imshow(step_density_exp_norm_mask * next_step_mask, cmap='viridis')
        axs_out[1, 1].set_title('Next-step probability')
        fig_out.colorbar(im4, ax=axs_out[1, 1], shrink=0.7)

        filename_covs = f'outputs/next_step_validation/id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}.png'
        plt.tight_layout()
        plt.savefig(filename_covs, dpi=600, bbox_inches='tight')
        plt.close(fig_out)  # close the figure to free memory

    # Diagnostic line plots of the log-probabilities through time
    # (only the validated, non-zero entries are plotted).
    for probs, title in ((habitat_probs, 'Habitat Probability'),
                         (move_probs, 'Movement Probability'),
                         (next_step_probs, 'Next-step Probability')):
        plt.plot(np.log(probs[probs > 0]))
        plt.xlabel('Time')
        plt.ylabel('Probability')
        plt.title(title)
        plt.show()

    # Append the probabilities to the dataframe (aligned by row index)
    buffalo_df['habitat_probs'] = habitat_probs
    buffalo_df['move_probs'] = move_probs
    buffalo_df['next_step_probs'] = next_step_probs
    csv_filename = f'../outputs/next_step_probs_TAmix_id{buffalo_id}_{today_date}.csv'

    print(csv_filename)
    # buffalo_df.to_csv(csv_filename, index=True)

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[34], line 290
    287 fig_out.colorbar(im4, ax=axs_out[1, 1], shrink=0.7)
    289 filename_covs = f'outputs/next_step_validation/id{buffalo_id}_yday{yday_t2_integer}_hour{hour_t2_integer}.png'
--> 290 plt.tight_layout()
    291 plt.savefig(filename_covs, dpi=600, bbox_inches='tight')
    292 # plt.show()

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\pyplot.py:2801, in tight_layout(pad, h_pad, w_pad, rect)
   2793 @_copy_docstring_and_deprecators(Figure.tight_layout)
   2794 def tight_layout(
   2795     *,
   (...)
   2799     rect: tuple[float, float, float, float] | None = None,
   2800 ) -> None:
-> 2801     gcf().tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\figure.py:3545, in Figure.tight_layout(self, pad, h_pad, w_pad, rect)
   3543 previous_engine = self.get_layout_engine()
   3544 self.set_layout_engine(engine)
-> 3545 engine.execute(self)
   3546 if previous_engine is not None and not isinstance(
   3547     previous_engine, (TightLayoutEngine, PlaceHolderLayoutEngine)
   3548 ):
   3549     _api.warn_external('The figure layout has changed to tight')

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\layout_engine.py:183, in TightLayoutEngine.execute(self, fig)
    181 renderer = fig._get_renderer()
    182 with getattr(renderer, "_draw_disabled", nullcontext)():
--> 183     kwargs = get_tight_layout_figure(
    184         fig, fig.axes, get_subplotspec_list(fig.axes), renderer,
    185         pad=info['pad'], h_pad=info['h_pad'], w_pad=info['w_pad'],
    186         rect=info['rect'])
    187 if kwargs:
    188     fig.subplots_adjust(**kwargs)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\_tight_layout.py:266, in get_tight_layout_figure(fig, axes_list, subplotspec_list, renderer, pad, h_pad, w_pad, rect)
    261         return {}
    262     span_pairs.append((
    263         slice(ss.rowspan.start * div_row, ss.rowspan.stop * div_row),
    264         slice(ss.colspan.start * div_col, ss.colspan.stop * div_col)))
--> 266 kwargs = _auto_adjust_subplotpars(fig, renderer,
    267                                   shape=(max_nrows, max_ncols),
    268                                   span_pairs=span_pairs,
    269                                   subplot_list=subplot_list,
    270                                   ax_bbox_list=ax_bbox_list,
    271                                   pad=pad, h_pad=h_pad, w_pad=w_pad)
    273 # kwargs can be none if tight_layout fails...
    274 if rect is not None and kwargs is not None:
    275     # if rect is given, the whole subplots area (including
    276     # labels) will fit into the rect instead of the
   (...)
    280     # auto_adjust_subplotpars twice, where the second run
    281     # with adjusted rect parameters.

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\_tight_layout.py:82, in _auto_adjust_subplotpars(fig, renderer, shape, span_pairs, subplot_list, ax_bbox_list, pad, h_pad, w_pad, rect)
     80 for ax in subplots:
     81     if ax.get_visible():
---> 82         bb += [martist._get_tightbbox_for_layout_only(ax, renderer)]
     84 tight_bbox_raw = Bbox.union(bb)
     85 tight_bbox = fig.transFigure.inverted().transform_bbox(tight_bbox_raw)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\artist.py:1410, in _get_tightbbox_for_layout_only(obj, *args, **kwargs)
   1404 """
   1405 Matplotlib's `.Axes.get_tightbbox` and `.Axis.get_tightbbox` support a
   1406 *for_layout_only* kwarg; this helper tries to use the kwarg but skips it
   1407 when encountering third-party subclasses that do not support it.
   1408 """
   1409 try:
-> 1410     return obj.get_tightbbox(*args, **{**kwargs, "for_layout_only": True})
   1411 except TypeError:
   1412     return obj.get_tightbbox(*args, **kwargs)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\_api\deprecation.py:457, in make_keyword_only.<locals>.wrapper(*args, **kwargs)
    451 if len(args) > name_idx:
    452     warn_deprecated(
    453         since, message="Passing the %(name)s %(obj_type)s "
    454         "positionally is deprecated since Matplotlib %(since)s; the "
    455         "parameter will become keyword-only %(removal)s.",
    456         name=name, obj_type=f"parameter of {func.__name__}()")
--> 457 return func(*args, **kwargs)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\axes\_base.py:4476, in _AxesBase.get_tightbbox(self, renderer, call_axes_locator, bbox_extra_artists, for_layout_only)
   4474 for axis in self._axis_map.values():
   4475     if self.axison and axis.get_visible():
-> 4476         ba = martist._get_tightbbox_for_layout_only(axis, renderer)
   4477         if ba:
   4478             bb.append(ba)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\artist.py:1410, in _get_tightbbox_for_layout_only(obj, *args, **kwargs)
   1404 """
   1405 Matplotlib's `.Axes.get_tightbbox` and `.Axis.get_tightbbox` support a
   1406 *for_layout_only* kwarg; this helper tries to use the kwarg but skips it
   1407 when encountering third-party subclasses that do not support it.
   1408 """
   1409 try:
-> 1410     return obj.get_tightbbox(*args, **{**kwargs, "for_layout_only": True})
   1411 except TypeError:
   1412     return obj.get_tightbbox(*args, **kwargs)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\axis.py:1370, in Axis.get_tightbbox(self, renderer, for_layout_only)
   1368 if renderer is None:
   1369     renderer = self.figure._get_renderer()
-> 1370 ticks_to_draw = self._update_ticks()
   1372 self._update_label_position(renderer)
   1374 # go back to just this axis's tick labels

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\axis.py:1302, in Axis._update_ticks(self)
   1300 major_locs = self.get_majorticklocs()
   1301 major_labels = self.major.formatter.format_ticks(major_locs)
-> 1302 major_ticks = self.get_major_ticks(len(major_locs))
   1303 for tick, loc, label in zip(major_ticks, major_locs, major_labels):
   1304     tick.update_position(loc)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\axis.py:1672, in Axis.get_major_ticks(self, numticks)
   1670     tick = self._get_tick(major=True)
   1671     self.majorTicks.append(tick)
-> 1672     self._copy_tick_props(self.majorTicks[0], tick)
   1674 return self.majorTicks[:numticks]

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\axis.py:1618, in Axis._copy_tick_props(self, src, dest)
   1616 dest.label1.update_from(src.label1)
   1617 dest.label2.update_from(src.label2)
-> 1618 dest.tick1line.update_from(src.tick1line)
   1619 dest.tick2line.update_from(src.tick2line)
   1620 dest.gridline.update_from(src.gridline)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\lines.py:1357, in Line2D.update_from(self, other)
   1354 self._solidjoinstyle = other._solidjoinstyle
   1356 self._linestyle = other._linestyle
-> 1357 self._marker = MarkerStyle(marker=other._marker)
   1358 self._drawstyle = other._drawstyle

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\markers.py:248, in MarkerStyle.__init__(self, marker, fillstyle, transform, capstyle, joinstyle)
    246 self._user_joinstyle = JoinStyle(joinstyle) if joinstyle is not None else None
    247 self._set_fillstyle(fillstyle)
--> 248 self._set_marker(marker)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\markers.py:323, in MarkerStyle._set_marker(self, marker)
    321     self._marker_function = self._set_tuple_marker
    322 elif isinstance(marker, MarkerStyle):
--> 323     self.__dict__ = copy.deepcopy(marker.__dict__)
    324 else:
    325     try:

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\copy.py:136, in deepcopy(x, memo, _nil)
    134 copier = _deepcopy_dispatch.get(cls)
    135 if copier is not None:
--> 136     y = copier(x, memo)
    137 else:
    138     if issubclass(cls, type):

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\copy.py:221, in _deepcopy_dict(x, memo, deepcopy)
    219 memo[id(x)] = y
    220 for key, value in x.items():
--> 221     y[deepcopy(key, memo)] = deepcopy(value, memo)
    222 return y

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\copy.py:143, in deepcopy(x, memo, _nil)
    141 copier = getattr(x, "__deepcopy__", None)
    142 if copier is not None:
--> 143     y = copier(memo)
    144 else:
    145     reductor = dispatch_table.get(cls)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\site-packages\matplotlib\path.py:285, in Path.__deepcopy__(self, memo)
    280 """
    281 Return a deepcopy of the `Path`.  The `Path` will not be
    282 readonly, even if the source `Path` is.
    283 """
    284 # Deepcopying arrays (vertices, codes) strips the writeable=False flag.
--> 285 p = copy.deepcopy(super(), memo)
    286 p._readonly = False
    287 return p

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\copy.py:162, in deepcopy(x, memo, _nil)
    160                 y = x
    161             else:
--> 162                 y = _reconstruct(x, memo, *rv)
    164 # If is its own copy, don't memoize.
    165 if y is not x:

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\copy.py:259, in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
    257 if state is not None:
    258     if deep:
--> 259         state = deepcopy(state, memo)
    260     if hasattr(y, '__setstate__'):
    261         y.__setstate__(state)

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\copy.py:136, in deepcopy(x, memo, _nil)
    134 copier = _deepcopy_dispatch.get(cls)
    135 if copier is not None:
--> 136     y = copier(x, memo)
    137 else:
    138     if issubclass(cls, type):

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\copy.py:221, in _deepcopy_dict(x, memo, deepcopy)
    219 memo[id(x)] = y
    220 for key, value in x.items():
--> 221     y[deepcopy(key, memo)] = deepcopy(value, memo)
    222 return y

File c:\Users\for329\AppData\Local\miniconda3\envs\deep_ssf_env2\Lib\copy.py:143, in deepcopy(x, memo, _nil)
    141 copier = getattr(x, "__deepcopy__", None)
    142 if copier is not None:
--> 143     y = copier(memo)
    144 else:
    145     reductor = dispatch_table.get(cls)

KeyboardInterrupt: 

Code
# Sanity-check the exp/log round trip used when moving between
# log-probabilities and probabilities.
value = -7.5
print(np.exp(value))  # log-probability -> probability

value = 6e-4
print(np.log(value))  # probability -> log-probability

# buffalo_df
0.0005530843701478336
-7.418580902748128